﻿#include "proxy_handler_impl.hh"
#include "box/service_box.hh"
#include "config/box_config.hh"
#include "lang_impl.hh"
#include "service_finder/service_finder.hh"
#include "service_register/service_register.hh"
#include "util/os_util.hh"
#include "util/string_util.hh"
#include "util/time_util.hh"
#include "util/timer_wheel.hh"

namespace {
// File-local empty transport returned by lookups that find nothing.
// It must stay mutable: get_outside_transport() returns it by non-const
// reference.
rpc::TransportPtr NullTransPtr;
} // namespace

/// Peeks a fixed-size header of type T from the front of `transport` (without
/// consuming it) and converts it to host byte order via `T::ntoh()`.
///
/// When the peek does not yield exactly sizeof(T) bytes, a warning naming T is
/// logged through `box` and the transport is closed.
///
/// @return true when `header` was filled and converted, false otherwise.
template <typename T>
inline static auto copy_header(kratos::service::ServiceBox *box, T &header,
                               rpc::TransportPtr &transport) -> bool {
  auto *raw = reinterpret_cast<char *>(&header);
  const auto peeked = transport->peek(raw, sizeof(T));
  if (peeked == sizeof(T)) {
    header.ntoh();
    return true;
  }
  // Incomplete header: report which header type failed and drop the peer.
  box->write_log(kratos::lang::LangID::LANG_PROXY_INVALID_HEADER,
                 klogger::Logger::WARNING,
                 kratos::util::demangle(typeid(T).name()).c_str());
  transport->close();
  return false;
}

/// Binds the handler to `box`, creates the timer wheel used for relayed-call
/// timeouts and applies the "proxy.*" configuration.
rpc::ProxyHandlerImpl::ProxyHandlerImpl(kratos::service::ServiceBox *box)
    : box_(box) {
  timer_wheel_ = kratos::make_unique_pool_ptr<kratos::util::TimerWheel>(box_);
  // The result is not checked here; load_config() logs its own failures.
  load_config();
}

// Members release their own resources; nothing extra to do.
rpc::ProxyHandlerImpl::~ProxyHandlerImpl() = default;

/// Routes a message through the proxy according to its type.
///
/// - RPC_PROXY_CALL / RPC_PROXY_RETURN come from inside the cluster and go out.
/// - RPC_CALL / RPC_RETURN come from an outside connection and go in.
///
/// @return the handler's result; false for any other message type.
bool rpc::ProxyHandlerImpl::onRelay(TransportPtr &transport,
                                    const RpcMsgHeader &header) {
  switch (static_cast<RpcMsgType>(header.type)) {
  case RpcMsgType::RPC_PROXY_CALL:
    return inside_call_outside(transport, header);
  case RpcMsgType::RPC_PROXY_RETURN:
    return inside_return_outside(transport, header);
  case RpcMsgType::RPC_CALL:
    return outside_call_inside(transport, header);
  case RpcMsgType::RPC_RETURN:
    return outside_return_inside(transport, header);
  default:
    return false;
  }
}

/// Periodic tick: advances the timer wheel to `now`, which fires the fallback
/// timeouts of relayed calls that have not returned.
void rpc::ProxyHandlerImpl::update(std::time_t now) {
  auto *wheel = timer_wheel_.get();
  wheel->update(now);
}

/// Sets the seed that distinguishes this proxy from the others in the cluster
/// (the original design note allows up to 255 proxies per cluster).
///
/// new_global_index() places the seed in the top byte of a 4-byte GlobalIndex.
/// The remaining 3 bytes hold an auto-incremented serial.
auto rpc::ProxyHandlerImpl::set_seed(std::uint8_t seed) -> void {
  seed_ = seed;
}

/// Handles a newly accepted outside connection.
///
/// The connection gets an unused GlobalIndex and is registered under it. If
/// no index is available, the connection is refused and closed. A registration
/// failure is only logged; the connection is left open.
auto rpc::ProxyHandlerImpl::on_accept(rpc::TransportPtr transport) -> void {
  const auto global_index = new_global_index();
  if (global_index != rpc::INVALID_GLOBAL_INDEX) {
    if (!add_outside_transport(global_index, transport)) {
      box_->write_log(
          kratos::lang::LangID::LANG_PROXY_ADD_OUTSIDE_TRANSPORT_FAILED,
          klogger::Logger::FAILURE);
    }
    return;
  }
  // Index space exhausted: refuse the connection.
  box_->write_log(kratos::lang::LangID::LANG_PROXY_VIRTUAL_ID_EXHAUSTED,
                  klogger::Logger::FAILURE);
  transport->close();
}

/// Handles the close of an outside connection by releasing the proxy state
/// tied to it. The cleanup result is not used.
auto rpc::ProxyHandlerImpl::on_close(rpc::TransportPtr transport) -> void {
  static_cast<void>(remove_outside_transport(transport));
}

/// Allocates a GlobalIndex for a newly accepted outside connection.
///
/// Indices returned through recycle_global_index() are reused first.
/// Otherwise a fresh index is composed as (seed_ << 24) | serial, where serial
/// is a per-proxy counter in [1, MAX_SERIAL].
///
/// @return the allocated index, or rpc::INVALID_GLOBAL_INDEX once the serial
///         space is exhausted.
auto rpc::ProxyHandlerImpl::new_global_index() -> rpc::GlobalIndex {
  if (!global_index_pool_.empty()) {
    auto it = global_index_pool_.begin();
    auto index = *it;
    global_index_pool_.erase(it);
    return index;
  }
  // serial_ is pre-incremented below, so the largest serial that may be handed
  // out is MAX_SERIAL itself. A `serial_ > MAX_SERIAL` test would let
  // MAX_SERIAL + 1 through. Assuming MAX_SERIAL is the 24-bit maximum
  // (0xFFFFFF), that value carries into the seed byte and aliases an index
  // belonging to the next proxy.
  if (serial_ >= MAX_SERIAL) {
    return rpc::INVALID_GLOBAL_INDEX;
  }
  rpc::GlobalIndex index = seed_;
  index <<= 24;
  index |= ++serial_;
  return index;
}

/// Resolves the in-cluster transport for service `uuid` on behalf of the
/// outside connection `from`. The lookup cache is keyed by the GlobalIndex of
/// `from`.
auto rpc::ProxyHandlerImpl::get_service_transport_by_uuid(TransportPtr &from,
                                                          rpc::ServiceUUID uuid)
    -> TransportPtr {
  const auto owner_index = from->getGlobalIndex();
  return get_inside_service_transport(uuid, owner_index);
}

/// Finds the in-cluster transport serving `uuid` for the outside connection
/// identified by `global_index`.
///
/// 1. A cached, still-open transport is returned directly. A cached but
///    closed one is evicted.
/// 2. Otherwise a synchronous service discovery is performed. It may block up
///    to query_timeout_ milliseconds. A successful result is cached.
///
/// TODO: service invalidation. Currently cached entries are only cleaned up
/// when a connection closes (see remove_outside_transport()).
/// TODO: guard against malicious peers triggering discovery floods.
///
/// @return the transport, or NullTransPtr when discovery found nothing.
auto rpc::ProxyHandlerImpl::get_inside_service_transport(
    rpc::ServiceUUID uuid, rpc::GlobalIndex global_index) -> TransportPtr {
  {
    // Per-connection cache; operator[] creates an empty one on first use.
    auto &uuid_channel_map = global_indexer_map_[global_index];
    auto uuid_it = uuid_channel_map.find(uuid);
    if (uuid_it != uuid_channel_map.end()) {
      if (!uuid_it->second->isClose()) {
        return uuid_it->second;
      }
      // Cached service is no longer usable; fall through to rediscover it.
      uuid_channel_map.erase(uuid_it);
    }
  }
  auto uuid_str = std::to_string(uuid);
  auto transport = box_->get_transport_sync(uuid_str, query_timeout_);
  if (!transport) {
    box_->write_log(kratos::lang::LangID::LANG_PROXY_FIND_SERVICE_TIMEOUT,
                    klogger::Logger::FAILURE, uuid_str.c_str());
    return NullTransPtr;
  }
  // Look the cache up again instead of reusing a reference taken before the
  // blocking call. If get_transport_sync() lets other handlers run (e.g.
  // on_close -> remove_outside_transport), the entry may have been erased or
  // recreated meanwhile. A stale reference would then dangle.
  auto index_it = global_indexer_map_.find(global_index);
  if (index_it != global_indexer_map_.end()) {
    index_it->second.emplace(uuid, transport);
  }
  return transport;
}

/// Extracts the GlobalIndex carried by a proxy call/return header peeked from
/// `from`.
///
/// @return the GlobalIndex, or INVALID_GLOBAL_INDEX when the message is neither
///         RPC_PROXY_CALL nor RPC_PROXY_RETURN, or when its header cannot be
///         read (copy_header() then closes `from`).
auto rpc::ProxyHandlerImpl::get_global_index(TransportPtr &from,
                                             const RpcMsgHeader &header)
    -> GlobalIndex {
  switch (static_cast<RpcMsgType>(header.type)) {
  case RpcMsgType::RPC_PROXY_CALL: {
    static RpcProxyCallHeader proxyCall;
    if (copy_header<rpc::RpcProxyCallHeader>(box_, proxyCall, from)) {
      return proxyCall.callHeader.globalIndex;
    }
    return INVALID_GLOBAL_INDEX;
  }
  case RpcMsgType::RPC_PROXY_RETURN: {
    static RpcProxyRetHeader proxyRet;
    if (copy_header<rpc::RpcProxyRetHeader>(box_, proxyRet, from)) {
      return proxyRet.retHeader.globalIndex;
    }
    return INVALID_GLOBAL_INDEX;
  }
  default:
    return rpc::INVALID_GLOBAL_INDEX;
  }
}

/// Picks the in-cluster transport a message from outside connection `from`
/// must be forwarded to.
///
/// - RPC_CALL: the transport serving the called service UUID.
/// - RPC_RETURN: the transport that issued the matching inside-to-outside
///   call. The pending-call record is consumed, and `real_call_id` receives
///   the original call ID.
/// - Any other type is a protocol error and `from` is closed.
///
/// @return the transport, or NullTransPtr when none is available.
auto rpc::ProxyHandlerImpl::get_inside_service_transport(
    TransportPtr &from, const RpcMsgHeader &header, rpc::CallID &real_call_id)
    -> TransportPtr {
  switch (static_cast<RpcMsgType>(header.type)) {
  case RpcMsgType::RPC_CALL: {
    static rpc::RpcCallHeader incomingCall;
    if (!copy_header<rpc::RpcCallHeader>(box_, incomingCall, from)) {
      return NullTransPtr;
    }
    return get_service_transport_by_uuid(from,
                                         incomingCall.callHeader.serviceUUID);
  }
  case RpcMsgType::RPC_RETURN: {
    // An outside peer is answering a call that came from inside.
    static rpc::RpcRetHeader incomingRet;
    if (!copy_header<rpc::RpcRetHeader>(box_, incomingRet, from)) {
      return NullTransPtr;
    }
    auto manager_it =
        inside_to_outside_call_info_map_.find(from->getGlobalIndex());
    if (manager_it == inside_to_outside_call_info_map_.end()) {
      return NullTransPtr;
    }
    // The call has returned: drop its record and recover the caller's pipe.
    ServiceUUID service_uuid;
    auto caller = manager_it->second.remove(incomingRet.retHeader.callID,
                                            service_uuid, real_call_id);
    if (caller && !caller->isClose()) {
      return caller;
    }
    box_->write_log(
        kratos::lang::LangID::LANG_PROXY_NOT_FOUND_SERVICE_FOR_RETURN,
        klogger::Logger::WARNING);
    return NullTransPtr;
  }
  default:
    // Protocol error.
    from->close();
    return NullTransPtr;
  }
}

/// Answers `from` on behalf of the cluster with an RPC_RETURN whose error is
/// RpcError::NOT_FOUND, echoing the call ID from `callHeader`.
auto rpc::ProxyHandlerImpl::service_not_found(
    TransportPtr &from, const rpc::RpcCallHeader &callHeader) -> void {
  // NOTE(review): this is a reused function-local static. Any field not
  // assigned below keeps the byte order left by the previous hton().
  static rpc::RpcRetHeader header;
  header.header.type = static_cast<rpc::MsgTypeID>(rpc::RpcMsgType::RPC_RETURN);
  header.header.length = sizeof(header);
  header.retHeader.serviceID = 0;
  header.retHeader.errorID = static_cast<rpc::ErrorID>(rpc::RpcError::NOT_FOUND);
  header.retHeader.callID = callHeader.callHeader.callID;
  header.hton();
  from->send(reinterpret_cast<char *>(&header), sizeof(header));
}

/// Records an inside-to-outside proxy call so that its return can be routed
/// back, and assigns it a fresh virtual call ID.
///
/// One-way calls are not recorded and leave `virtual_call_id` untouched.
///
/// @return false if the header cannot be read, or the record cannot be
///         created or inserted.
auto rpc::ProxyHandlerImpl::record_inside_to_outside_call_info(
    GlobalIndex global_index, TransportPtr &transport,
    rpc::CallID &virtual_call_id) -> bool {
  static rpc::RpcProxyCallHeader callHeader;
  if (!copy_header<rpc::RpcProxyCallHeader>(box_, callHeader, transport)) {
    return false;
  }
  const auto &call = callHeader.callHeader;
  if (call.oneWay) {
    // No return will ever come back, so there is nothing to track.
    return true;
  }
  virtual_call_id = new_virtual_id();
  auto manager_it = inside_to_outside_call_info_map_.find(global_index);
  if (manager_it == inside_to_outside_call_info_map_.end()) {
    // First call towards this outside connection: create its manager.
    auto inserted = inside_to_outside_call_info_map_.emplace(
        global_index, InsideCallOutsideManager(timer_wheel_.get()));
    if (!inserted.second) {
      return false;
    }
    manager_it = inserted.first;
  }
  return manager_it->second.add(call_timeout_, call.callID, call.serviceUUID,
                                virtual_call_id, transport);
}

/// Resolves the outside connection targeted by a proxy call/return peeked from
/// `transport`.
///
/// @param[out] global_index the GlobalIndex read from the header, or
///             INVALID_GLOBAL_INDEX.
/// @return a reference to the registered outside transport, or NullTransPtr.
auto rpc::ProxyHandlerImpl::get_outside_transport(TransportPtr &transport,
                                                  const RpcMsgHeader &header,
                                                  GlobalIndex &global_index)
    -> TransportPtr & {
  global_index = get_global_index(transport, header);
  if (global_index != rpc::INVALID_GLOBAL_INDEX) {
    auto it = global_index_transport_map_.find(global_index);
    if (it != global_index_transport_map_.end()) {
      return it->second;
    }
  }
  return NullTransPtr;
}

rpc::CallID rpc::ProxyHandlerImpl::new_virtual_id() {
  return proxy_virtual_call_id_++;
}

/// Applies the "proxy.*" section of the box configuration.
///
/// - proxy.query_timeout / proxy.call_timeout: optional timeouts, in
///   milliseconds, for synchronous service discovery and relayed calls.
/// - proxy.services: services registered on behalf of the cluster. Each one
///   is bound to a randomly chosen entry of "listener.host".
/// - proxy.service_finder.{type,hosts,timeout}: when present, an internal
///   service finder is started and replaces the box's current one.
///
/// @return false when the service-finder settings are incomplete, the finder
///         cannot be created or started, or reading the configuration throws;
///         true otherwise. Every failure is logged.
auto rpc::ProxyHandlerImpl::load_config() -> bool {
  try {
    auto &config = box_->get_config();
    if (config.has_attribute("proxy.query_timeout")) {
      // Service-discovery timeout, in milliseconds.
      query_timeout_ = config.get_number<std::time_t>("proxy.query_timeout");
    }
    if (config.has_attribute("proxy.call_timeout")) {
      // Relayed-call timeout, in milliseconds.
      call_timeout_ = config.get_number<std::time_t>("proxy.call_timeout");
    }
    if (!config.has_attribute("proxy.services")) {
      return true;
    }
    auto services = config.get_array<std::string>("proxy.services");
    if (services.empty()) {
      return true;
    }
    //
    // Register each proxied service with one of this box's listeners, standing
    // in for the services behind the proxy. The listener list does not depend
    // on the service, so it is read once instead of once per service.
    //
    const auto &listeners = config.get_array<std::string>("listener.host");
    for (const auto &service : services) {
      if (listeners.empty()) {
        box_->write_log_line(klogger::Logger::WARNING,
                             "[proxy]proxy.listener empty");
        continue;
      }
      auto index = kratos::util::get_random_uint32(
          0, static_cast<std::uint32_t>(listeners.size()) - 1);
      const auto &host = listeners[index];
      if (box_->get_service_register()->register_service(service, host)) {
        box_->write_log_line(klogger::Logger::INFORMATION,
                             "[proxy]register service [" + service +
                                 "], host[" + host + "]");
      } else {
        box_->write_log_line(klogger::Logger::FAILURE,
                             "[proxy]register service [" + service +
                                 "] failed");
      }
    }
    //
    // Connect to the internal service-discovery cluster.
    //
    if (!config.has_attribute("proxy.service_finder")) {
      box_->write_log_line(klogger::Logger::WARNING,
                           "[proxy]proxy.service_finder not found");
      return true;
    }
    if (!config.has_attribute("proxy.service_finder.type")) {
      box_->write_log_line(klogger::Logger::WARNING,
                           "[proxy]proxy.service_finder.type not set");
      return false;
    }
    if (!config.has_attribute("proxy.service_finder.hosts")) {
      box_->write_log_line(klogger::Logger::WARNING,
                           "[proxy]proxy.service_finder.hosts not set");
      return false;
    }
    auto finder_type = config.get_string("proxy.service_finder.type");
    auto finder_hosts = config.get_string("proxy.service_finder.hosts");
    int finder_timeout = 5000;
    if (config.has_attribute("proxy.service_finder.timeout")) {
      finder_timeout = config.get_number<int>("proxy.service_finder.timeout");
    }
    auto service_finder_ptr = kratos::service::getFinder(finder_type);
    // An unrecognised finder type must not be dereferenced below.
    if (!service_finder_ptr) {
      box_->write_log_line(klogger::Logger::FATAL,
                           "[proxy]Unknown service finder type [" +
                               finder_type + "]");
      return false;
    }
    if (!service_finder_ptr->start(finder_hosts, finder_timeout,
                                   config.get_version())) {
      box_->write_log_line(klogger::Logger::FATAL,
                           "[proxy]Connect service finder type [" +
                               finder_type + "] hosts[" + finder_hosts +
                               "] failed");
      return false;
    }
    // Replace the box's external service finder with the internal one.
    box_->set_service_finder(service_finder_ptr);
    return true;
  } catch (const std::exception &ex) {
    box_->write_log(kratos::lang::LangID::LANG_PROXY_CONFIG_ERROR,
                    klogger::Logger::EXCEPTION, ex.what());
  }
  return false;
}

/// Turns an in-cluster proxy call into a regular call to an outside peer.
///
/// Resolves the target outside connection and records the call so its return
/// can be routed back. It then relays the message with the assigned virtual
/// call ID (presumably substituted for the original ID by relay; confirm).
auto rpc::ProxyHandlerImpl::inside_call_outside(TransportPtr &transport,
                                                const RpcMsgHeader &header)
    -> bool {
  GlobalIndex global_index = INVALID_GLOBAL_INDEX;
  auto external_trans = get_outside_transport(transport, header, global_index);
  const bool outside_usable = external_trans && !external_trans->isClose();
  if (!outside_usable) {
    return false;
  }
  rpc::CallID virtual_call_id = INVALID_CALL_ID;
  const bool recorded = record_inside_to_outside_call_info(
      global_index, transport, virtual_call_id);
  if (!recorded) {
    return false;
  }
  return box_->get_rpc()->relay(transport, external_trans, header,
                                virtual_call_id);
}

/// Turns an in-cluster proxy return into a regular return to the outside peer
/// that made the call, and relays it there.
auto rpc::ProxyHandlerImpl::inside_return_outside(TransportPtr &transport,
                                                  const RpcMsgHeader &header)
    -> bool {
  GlobalIndex global_index = INVALID_GLOBAL_INDEX;
  auto external_trans = get_outside_transport(transport, header, global_index);
  if (external_trans && !external_trans->isClose()) {
    return box_->get_rpc()->relay(transport, external_trans, header);
  }
  return false;
}

/// Forwards a regular call from an outside peer to the in-cluster service.
///
/// If the service cannot be reached (discovery timeout or closed pipe), a
/// NOT_FOUND return is sent back to the caller instead.
auto rpc::ProxyHandlerImpl::outside_call_inside(TransportPtr &transport,
                                                const RpcMsgHeader &header)
    -> bool {
  static rpc::RpcCallHeader callHeader;
  if (!copy_header<rpc::RpcCallHeader>(box_, callHeader, transport)) {
    return false;
  }
  CallID unused_call_id = INVALID_CALL_ID;
  auto target = get_inside_service_transport(transport, header, unused_call_id);
  const bool unavailable = !target || target->isClose();
  if (unavailable) {
    service_not_found(transport, callHeader);
    return false;
  }
  return box_->get_rpc()->relay(transport, target, header);
}

/// Forwards an outside peer's return to the in-cluster service that issued
/// the call, restoring the original call ID recorded for it.
auto rpc::ProxyHandlerImpl::outside_return_inside(TransportPtr &transport,
                                                  const RpcMsgHeader &header)
    -> bool {
  rpc::CallID real_call_id = INVALID_CALL_ID;
  auto caller = get_inside_service_transport(transport, header, real_call_id);
  if (caller && !caller->isClose()) {
    return box_->get_rpc()->relay(transport, caller, header, real_call_id);
  }
  // The calling service is gone or its connection is closed.
  box_->write_log_line(klogger::Logger::INFORMATION, "Call not found");
  return false;
}

/// Forwards an outside subscription to the in-cluster publisher.
///
/// The subscriber and publisher transports are remembered under the sub ID,
/// so later publications and cancellations can be routed.
auto rpc::ProxyHandlerImpl::outside_sub_inside(TransportPtr &transport,
                                               const RpcMsgHeader &header)
    -> bool {
  static rpc::RpcSubHeader subHeader;
  if (!copy_header<rpc::RpcSubHeader>(box_, subHeader, transport)) {
    return false;
  }
  auto service_name = std::to_string(subHeader.serviceUUID);
  auto publisher = box_->get_transport_sync(service_name, query_timeout_);
  if (!publisher || publisher->isClose()) {
    return false;
  }
  const std::string sub_id(subHeader.sub_id, rpc::SUB_ID_LEN);
  sub_info_map_[sub_id] = SubInfo{transport, publisher};
  return relay(transport, publisher, subHeader.header.length);
}

/// Forwards an outside subscription cancellation to the recorded publisher and
/// forgets the subscription, whether or not the relay succeeded.
auto rpc::ProxyHandlerImpl::outside_cancel_inside(TransportPtr &transport,
                                                  const RpcMsgHeader &header)
    -> bool {
  static rpc::RpcCancelSubHeader cancelHeader;
  if (!copy_header<rpc::RpcCancelSubHeader>(box_, cancelHeader, transport)) {
    return false;
  }
  auto sub_it =
      sub_info_map_.find(std::string(cancelHeader.sub_id, rpc::SUB_ID_LEN));
  if (sub_it == sub_info_map_.end()) {
    return false;
  }
  const auto forwarded = relay(transport, sub_it->second.pub_trans_ptr,
                               cancelHeader.header.length);
  sub_info_map_.erase(sub_it);
  return forwarded;
}

/// Forwards an in-cluster publication to the outside subscriber recorded for
/// its sub ID.
///
/// If the subscriber's connection is closed, a cancellation is sent to the
/// publisher on the subscriber's behalf and the subscription is forgotten.
///
/// NOTE(review): copy_header() only peeks, so on the `false` paths the
/// publication body is not consumed from `transport` here. Confirm the caller
/// handles that.
auto rpc::ProxyHandlerImpl::inside_pub_outside(TransportPtr &transport,
                                               const RpcMsgHeader &header)
    -> bool {
  static rpc::RpcPubHeader pubHeader;
  if (!copy_header<rpc::RpcPubHeader>(box_, pubHeader, transport)) {
    return false;
  }
  std::string sub_id(pubHeader.sub_id, rpc::SUB_ID_LEN);
  auto it = sub_info_map_.find(sub_id);
  if (it == sub_info_map_.end()) {
    return false;
  }
  if (it->second.sub_trans_ptr->isClose()) {
    //
    // Send a cancel to the publisher to end the subscription.
    //
    // Value-initialised: the whole struct is sent on the wire. Any member not
    // assigned below must not carry indeterminate stack bytes.
    rpc::RpcCancelSubHeader cancelHeader{};
    cancelHeader.header.type =
        rpc::MsgTypeID(rpc::RpcMsgType::RPC_EVENT_CANCEL);
    cancelHeader.header.length = sizeof(cancelHeader);
    std::memcpy(cancelHeader.sub_id, pubHeader.sub_id, rpc::SUB_ID_LEN);
    cancelHeader.hton();
    it->second.pub_trans_ptr->send(
        reinterpret_cast<const char *>(&cancelHeader), sizeof(cancelHeader));
    // The subscriber disconnected.
    sub_info_map_.erase(it);
    return false;
  }
  return relay(transport, it->second.sub_trans_ptr, pubHeader.header.length);
}

/// Returns `global_index` to the pool reused by new_global_index().
///
/// @return true if it was added, false if it was already pooled.
auto rpc::ProxyHandlerImpl::recycle_global_index(GlobalIndex global_index)
    -> bool {
  const auto result = global_index_pool_.insert(global_index);
  return result.second;
}

/// Tags `transport` with `global_index` (its cluster-wide identity) and
/// registers the GlobalIndex -> transport mapping.
///
/// @return false if `global_index` was already mapped. The tag is set either
///         way.
auto rpc::ProxyHandlerImpl::add_outside_transport(GlobalIndex global_index,
                                                  TransportPtr &transport)
    -> bool {
  transport->setGlobalIndex(global_index);
  const auto inserted =
      global_index_transport_map_.emplace(global_index, transport);
  return inserted.second;
}

/// Releases every piece of proxy state owned by a closed outside connection:
/// 1. pending inside-to-outside calls (destroying the manager cancels their
///    timers),
/// 2. the GlobalIndex -> transport mapping,
/// 3. the per-connection in-cluster service cache,
/// and then recycles its GlobalIndex.
///
/// Cleanup only happens when `transport` is the connection currently
/// registered under its GlobalIndex.
///
/// @return true if the GlobalIndex was recycled; false if `transport` was not
///         the registered owner (nothing is touched) or the index was already
///         pooled.
auto rpc::ProxyHandlerImpl::remove_outside_transport(
    rpc::TransportPtr &transport) -> bool {
  auto global_index = transport->getGlobalIndex();
  auto global_index_trans_it = global_index_transport_map_.find(global_index);
  if (global_index_trans_it == global_index_transport_map_.end() ||
      global_index_trans_it->second != transport) {
    // Not the registered owner. Two cases:
    // - The connection was never registered. For example, on_accept() refused
    //   it on index exhaustion and never called setGlobalIndex().
    // - Its index has already been recycled and handed to another connection,
    //   e.g. after a repeated close notification.
    // Erasing state or recycling the index here would corrupt the pool or
    // another live connection.
    return false;
  }
  auto call_it = inside_to_outside_call_info_map_.find(global_index);
  if (call_it != inside_to_outside_call_info_map_.end()) {
    inside_to_outside_call_info_map_.erase(call_it);
  }
  global_index_transport_map_.erase(global_index_trans_it);
  auto global_indexer_it = global_indexer_map_.find(global_index);
  if (global_indexer_it != global_indexer_map_.end()) {
    global_indexer_map_.erase(global_indexer_it);
  }
  return recycle_global_index(global_index);
}

/// Moves `length` bytes from `from` to `to` through a shared scratch buffer.
///
/// The buffer starts at rpc::STREAMBUF_SIZE bytes and is grown on demand
/// (never shrunk).
///
/// NOTE(review): `length` comes from a peer-supplied header. It has no upper
/// bound here, it is narrowed to int, and the recv() result is not checked.
/// Confirm that callers guarantee the full message is buffered and sane.
///
/// @return false if either transport is closed, true otherwise.
auto rpc::ProxyHandlerImpl::relay(rpc::TransportPtr &from,
                                  rpc::TransportPtr &to, std::size_t length)
    -> bool {
  static auto buffer = std::make_unique<char[]>(rpc::STREAMBUF_SIZE);
  static auto cur_length = std::size_t(rpc::STREAMBUF_SIZE);
  // Check first so no buffer growth happens for a relay that cannot proceed.
  if (from->isClose() || to->isClose()) {
    return false;
  }
  if (length > cur_length) {
    // Allocate before recording the new capacity. If make_unique throws,
    // cur_length must still describe the buffer that actually exists.
    // Otherwise a later message could be written past its end.
    buffer = std::make_unique<char[]>(length);
    cur_length = length;
  }
  from->recv(buffer.get(), int(length));
  to->send(buffer.get(), int(length));
  return true;
}

/// Creates a pending-call tracker whose timeouts run on `timer_wheel`. The
/// wheel is not owned.
rpc::ProxyHandlerImpl::InsideCallOutsideManager::InsideCallOutsideManager(
    kratos::util::TimerWheel *timer_wheel)
    : timer_wheel_(timer_wheel) {}

/// Cancels the timeout timer of every call still pending. Their callbacks
/// capture this manager and must not fire after it is gone.
rpc::ProxyHandlerImpl::InsideCallOutsideManager::~InsideCallOutsideManager() {
  for (auto &entry : call_map_) {
    timer_wheel_->cancel(entry.second.timer_id);
  }
}

auto rpc::ProxyHandlerImpl::InsideCallOutsideManager::add(
    std::time_t call_timeout, rpc::CallID callID, rpc::ServiceUUID service_uuid,
    rpc::CallID virtual_call_id, TransportPtr &transport) -> bool {
  // 启动保底超时定时器
  auto timer_id = timer_wheel_->schedule_once(
      [&](kratos::util::TimerID /*timer_id*/, std::uint64_t vid) -> bool {
        call_map_.erase((CallID)vid);
        return false;
      },
      call_timeout, virtual_call_id);
  call_map_[virtual_call_id] =
      CallInfo{callID, service_uuid, timer_id, transport};
  return true;
}

/// Consumes the pending-call record stored under virtual ID `callID`.
///
/// The record's timer is cancelled and the record erased.
///
/// @param[out] service_uuid the called service's UUID (set only when found).
/// @param[out] real_call_id the caller's original call ID (set only when found).
/// @return the transport that issued the call, or NullTransPtr if not found.
auto rpc::ProxyHandlerImpl::InsideCallOutsideManager::remove(
    rpc::CallID callID, rpc::ServiceUUID &service_uuid,
    rpc::CallID &real_call_id) -> TransportPtr {
  auto found = call_map_.find(callID);
  if (found == call_map_.end()) {
    return NullTransPtr;
  }
  auto &info = found->second;
  service_uuid = info.service_uuid;
  real_call_id = info.callID;
  // Copy the transport out before the record is erased.
  TransportPtr caller = info.transport;
  timer_wheel_->cancel(info.timer_id);
  call_map_.erase(found);
  return caller;
}
